/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
 * \file block_mmad_iterbatch.h
 * \brief
 */

#ifndef MATMUL_BLOCK_BLOCK_MMAD_ITERBATCH_H
#define MATMUL_BLOCK_BLOCK_MMAD_ITERBATCH_H
#include "./block_mmad.h"
#include "../../utils/layout_utils.h"
#include "../../utils/tuple_utils.h"
#include "../policy/dispatch_policy.h"
#include "../tile/load_data/tbuf_pool_l0_default.h"
#include "../tile/tile_copy.h"

namespace Act {
namespace Gemm {
namespace Block {
template <class L1TileShape, class L0TileShape, class AT, class BT, class CT, class BiasT, class TileCopy>
class BlockMmad<MatmulIterBatch<>, L1TileShape, L0TileShape, AT, BT, CT, BiasT, TileCopy,
    AscendC::Std::enable_if_t<IsMatmulLayoutTypeV<AT>>>
    : public BlockMmad<MatmulIterBatch<>, L1TileShape, L0TileShape,
        ToMatmulTypeT<AT>, ToMatmulTypeT<BT>, ToMatmulTypeT<CT>, ToMatmulTypeT<BiasT>, TileCopy> {
    using Base = BlockMmad<MatmulIterBatch<>, L1TileShape, L0TileShape,
                           ToMatmulTypeT<AT>, ToMatmulTypeT<BT>, ToMatmulTypeT<CT>, ToMatmulTypeT<BiasT>, TileCopy>;
    using Base::Base;
};

template <class L1TileShape_, class L0TileShape_, class AType_, class BType_, class CType_, class BiasType_,
          class TileCopy_>
class BlockMmad<MatmulIterBatch<>, L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, TileCopy_,
    AscendC::Std::enable_if_t<!IsMatmulLayoutTypeV<AType_>>> {
public:
    using AType = AType_;
    using BType = BType_;
    using CType = CType_;
    using A_T = typename AType::T;
    using B_T = typename BType::T;
    using C_T = typename CType::T;
    using DispatchPolicy = MatmulIterBatch<>;
    using TupleShape = AscendC::Shape<int64_t, int64_t, int64_t, int64_t>;
    uint64_t m_;
    uint64_t n_;
    uint64_t k_;
    uint64_t alignedM_{1};
    uint64_t alignedN_{1};
    uint64_t alignedK_{1};
    constexpr static uint64_t BUFFER_NUM = 2;
    uint64_t abL1EventID_{0};
    uint64_t l0EventID_{0};
    uint64_t l0CEventID_{0};
    uint64_t l0AOffset_ = L0A_SIZE / BUFFER_NUM / sizeof(A_T);
    uint64_t l0BOffset_ = L0B_SIZE / BUFFER_NUM / sizeof(B_T);
    uint64_t l0COffset_ = L0C_SIZE / BUFFER_NUM / sizeof(float);

    __aicore__ inline BlockMmad()
    {
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(ZERO_FLAG);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(FIRST_FLAG);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(SECOND_FLAG);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(THIRD_FLAG);
        AscendC::SetFlag<AscendC::HardEvent::FIX_M>(ZERO_FLAG);
        AscendC::SetFlag<AscendC::HardEvent::FIX_M>(FIRST_FLAG);
    }

    __aicore__ inline ~BlockMmad()
    {
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(ZERO_FLAG);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(FIRST_FLAG);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(SECOND_FLAG);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(THIRD_FLAG);
        AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(ZERO_FLAG);
        AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(FIRST_FLAG);
    }

public:
    __aicore__ inline void Init(const TupleShape& shape)
    {
        m_ = Get<DIMENSION_M>(shape);
        n_ = Get<DIMENSION_N>(shape);
        k_ = Get<DIMENSION_K>(shape);
        // when fp16 or (fp32 and m,k), m align to 16; when fp32 and k,m, m align to 8 * 2 for frac combine in loadtol0a
        alignedM_ = CeilAlign(m_, AscendC::BLOCK_CUBE);
        // when fp16 or (fp32 and n,k), m align to 16; when fp32 and k,n, n align to 8 * 2 for frac combine in loadtol0b
        alignedN_ = CeilAlign(n_, AscendC::BLOCK_CUBE);
        alignedK_ = CeilAlign(k_, AscendC::BLOCK_CUBE);
        l0EventID_ = 0;
        abL1EventID_ = 0;
    }
    __aicore__ inline void CopyInA1(const AscendC::GlobalTensor<A_T>& aGlobal,
                                    const AscendC::LocalTensor<A_T>& al1Local, const uint64_t curIterBatchL1,
                                    const uint64_t mInGM, const uint64_t kaInGM, const uint64_t mInL1A,
                                    const uint64_t kaInL1A)
    {
        AscendC::Nd2NzParams nd2nzParams;
        nd2nzParams.ndNum = curIterBatchL1; // represent how many matrices load to l1, so it is iterbatchl1 num.
        nd2nzParams.nValue = AType::isTrans ? kaInGM : mInGM;          // source N size
        nd2nzParams.dValue = AType::isTrans ? mInGM : kaInGM;          // source D size
        nd2nzParams.srcNdMatrixStride = mInGM * kaInGM;                // source gap of one block
        nd2nzParams.srcDValue = AType::isTrans ? mInGM : kaInGM;       // source D gap of one block
        nd2nzParams.dstNzC0Stride = AType::isTrans ? kaInL1A : mInL1A; // dst gap of one fractal nz block
        nd2nzParams.dstNzNStride = 1;                                  // dst N gap of one block, unit of matrix
        nd2nzParams.dstNzMatrixStride = mInL1A * kaInL1A;
        AscendC::DataCopy(al1Local, aGlobal, nd2nzParams);
    }

    __aicore__ inline void CopyInB1(const AscendC::GlobalTensor<B_T>& bGlobal,
                                    const AscendC::LocalTensor<B_T>& bl1Local, const uint64_t curIterBatchL1,
                                    const uint64_t kbInGM, const uint64_t nInGM, const uint64_t kbInL1B,
                                    const uint64_t nInL1B)
    {
        AscendC::Nd2NzParams nd2nzParams;
        nd2nzParams.ndNum = curIterBatchL1;
        uint64_t nDim = BType::isTrans ? nInGM : kbInGM;
        uint64_t dDim = BType::isTrans ? kbInGM : nInGM;

        nd2nzParams.nValue = BType::isTrans ? nInGM : kbInGM;
        nd2nzParams.dValue = BType::isTrans ? kbInGM : nInGM;
        nd2nzParams.srcNdMatrixStride = nInGM * kbInGM;
        nd2nzParams.srcDValue = BType::isTrans ? kbInGM : nInGM;
        nd2nzParams.dstNzC0Stride = BType::isTrans ? nInL1B : kbInL1B;
        nd2nzParams.dstNzNStride = 1;
        nd2nzParams.dstNzMatrixStride = nInL1B * kbInL1B;
        AscendC::DataCopy(bl1Local, bGlobal, nd2nzParams);
    }

    __aicore__ inline void CopyInA2(const AscendC::LocalTensor<A_T>& a2Local, const AscendC::LocalTensor<A_T>& al1Local,
        uint64_t kaInL0A, uint64_t mInL0A, uint64_t kaInL1A, uint64_t mInL1A, uint64_t curIterBatchL0)
    {
        if constexpr (!AType::isTrans) {
            AscendC::LoadData2DParamsV2 loadDataParams;
            loadDataParams.mStartPosition = 0;
            loadDataParams.kStartPosition = 0;
            loadDataParams.mStep = CeilDiv(mInL0A, AscendC::BLOCK_CUBE);
            loadDataParams.kStep = CeilDiv(kaInL0A, AscendC::AuxGetC0Size<A_T>());
            loadDataParams.srcStride = CeilDiv(mInL1A, AscendC::BLOCK_CUBE);
            loadDataParams.dstStride = loadDataParams.mStep;
            loadDataParams.ifTranspose = false;
            for (uint64_t iterL0AIndex = 0; iterL0AIndex < curIterBatchL0; iterL0AIndex++) {
                if constexpr (AscendC::IsSameType<A_T, bfloat16_t>::value) {
                    AscendC::LoadData(a2Local[iterL0AIndex * mInL1A * kaInL1A],
                                      al1Local[iterL0AIndex * mInL1A * kaInL1A], loadDataParams);
                } else {
                    AscendC::LoadData<A_T>(a2Local[iterL0AIndex * mInL1A * kaInL1A],
                                           al1Local[iterL0AIndex * mInL1A * kaInL1A], loadDataParams);
                }
            }
        } else {
            AscendC::LoadData2DParamsV2 loadDataParams;
            loadDataParams.mStartPosition = 0;
            loadDataParams.kStartPosition = 0;
            loadDataParams.mStep = CeilDiv(kaInL0A, AscendC::BLOCK_CUBE);
            loadDataParams.kStep = CeilDiv(CeilAlign(mInL0A, AscendC::BLOCK_CUBE), AscendC::AuxGetC0Size<A_T>());
            loadDataParams.srcStride = CeilDiv(kaInL1A, AscendC::BLOCK_CUBE);
            loadDataParams.dstStride = CeilDiv(mInL0A, AscendC::BLOCK_CUBE);
            loadDataParams.ifTranspose = true;
            for (uint64_t iterL0AIndex = 0; iterL0AIndex < curIterBatchL0; iterL0AIndex++) {
                if constexpr (AscendC::IsSameType<A_T, bfloat16_t>::value) {
                    AscendC::LoadData(a2Local[iterL0AIndex * mInL1A * kaInL1A],
                                      al1Local[iterL0AIndex * mInL1A * kaInL1A], loadDataParams);
                } else {
                    AscendC::LoadData<A_T>(a2Local[iterL0AIndex * mInL1A * kaInL1A],
                                           al1Local[iterL0AIndex * mInL1A * kaInL1A], loadDataParams);
                }
            }
        }
    }

    __aicore__ inline void CopyInB2(const AscendC::LocalTensor<B_T>& b2Local, const AscendC::LocalTensor<B_T>& bl1Local,
        uint64_t kbInL0B, uint64_t nInL0B, uint64_t kbInL1B, uint64_t nInL1B, uint64_t curIterBatchL0)
    {
        if constexpr (BType::isTrans) {
            AscendC::LoadData2DParamsV2 loadDataParams;
            loadDataParams.mStartPosition = 0;
            loadDataParams.kStartPosition = 0;
            loadDataParams.mStep = CeilDiv(nInL0B, AscendC::BLOCK_CUBE);
            loadDataParams.kStep = CeilDiv(kbInL0B, AscendC::AuxGetC0Size<B_T>());
            loadDataParams.srcStride = CeilDiv(nInL1B, AscendC::BLOCK_CUBE);
            loadDataParams.dstStride = loadDataParams.mStep;
            loadDataParams.ifTranspose = false;
            for (uint64_t iterL0BIndex = 0; iterL0BIndex < curIterBatchL0; iterL0BIndex++) {
                if constexpr (AscendC::IsSameType<B_T, bfloat16_t>::value) {
                    AscendC::LoadData(b2Local[iterL0BIndex * kbInL1B * nInL1B],
                                      bl1Local[iterL0BIndex * kbInL1B * nInL1B], loadDataParams);
                } else {
                    AscendC::LoadData<B_T>(b2Local[iterL0BIndex * kbInL1B * nInL1B],
                                           bl1Local[iterL0BIndex * kbInL1B * nInL1B], loadDataParams);
                }
            }
        } else {
            AscendC::LoadData2DParamsV2 loadDataParams;
            loadDataParams.mStartPosition = 0;
            loadDataParams.kStartPosition = 0;
            loadDataParams.mStep = CeilDiv(kbInL0B, AscendC::BLOCK_CUBE);
            loadDataParams.kStep = CeilDiv(CeilAlign(nInL0B, AscendC::BLOCK_CUBE), AscendC::AuxGetC0Size<B_T>());
            loadDataParams.srcStride = CeilDiv(kbInL1B, AscendC::BLOCK_CUBE);
            loadDataParams.dstStride = CeilDiv(nInL0B, AscendC::BLOCK_CUBE);
            loadDataParams.ifTranspose = true;
            for (uint64_t iterL0BIndex = 0; iterL0BIndex < curIterBatchL0; iterL0BIndex++) {
                if constexpr (AscendC::IsSameType<B_T, bfloat16_t>::value) {
                    AscendC::LoadData(b2Local[iterL0BIndex * kbInL1B * nInL1B],
                                      bl1Local[iterL0BIndex * kbInL1B * nInL1B], loadDataParams);
                } else {
                    AscendC::LoadData<B_T>(b2Local[iterL0BIndex * kbInL1B * nInL1B],
                                           bl1Local[iterL0BIndex * kbInL1B * nInL1B], loadDataParams);
                }
            }
        }
    }

    __aicore__ inline void Mmad(const AscendC::LocalTensor<A_T>& l0a,
                                const AscendC::LocalTensor<B_T>& l0b,
                                const AscendC::LocalTensor<float>& l0c,
                                const uint64_t mInGM, const uint64_t nInGM, const uint64_t kInGM,
                                const uint64_t mInL0a, const uint64_t kaInL0a,
                                const uint64_t kbInL0b, const uint64_t nInL0b,
                                const uint64_t mInL0c, const uint64_t nInL0c,
                                const uint64_t curIterBatchL0, const bool cmatrixInitVal)
    {
        AscendC::MmadParams mmadParams;
        mmadParams.m = mInGM;
        mmadParams.n = nInGM;
        mmadParams.k = kInGM;
        mmadParams.unitFlag = 0; // each l0 only process one block, disable unit flag.
        mmadParams.cmatrixInitVal = cmatrixInitVal;
        mmadParams.disableGemv = true; // disable gemv when m equals 1, which is not capable.
        for (uint64_t iterL0CIndex = 0; iterL0CIndex < curIterBatchL0; iterL0CIndex++) {
            AscendC::Mmad(l0c[iterL0CIndex * mInL0c * nInL0c],
                          l0a[iterL0CIndex * mInL0a * kaInL0a],
                          l0b[iterL0CIndex * kbInL0b * nInL0b], mmadParams);
        }
    }

    __aicore__ inline void CopyOut(const AscendC::GlobalTensor<C_T>& cGlobal, const AscendC::LocalTensor<float>& l0c,
                                   const uint64_t mInGM, const uint64_t nInGM, const uint64_t curIterBatchL0)
    {
        AscendC::DataCopyCO12DstParams intriParams;
        intriParams.nSize = nInGM;
        intriParams.mSize = mInGM;
        intriParams.dstStride = n_;
        intriParams.srcStride = Align(mInGM, AscendC::BLOCK_CUBE);
        if constexpr (AscendC::IsSameType<C_T, bfloat16_t>::value) {
            intriParams.quantPre = QuantMode_t::F322BF16;
        } else if (AscendC::IsSameType<C_T, half>::value) {
            intriParams.quantPre = QuantMode_t::F322F16;
        }
        intriParams.nz2ndEn = true;
        intriParams.unitFlag = 0;

        // When nz2nd loop in copyout, src stride is unit of c0Size, dst stride is unit of one element.
        AscendC::SetFixpipeNz2ndFlag(curIterBatchL0, Align(mInGM, AscendC::BLOCK_CUBE) *
                                     Align(nInGM, AscendC::BLOCK_CUBE) / AscendC::BLOCK_CUBE, mInGM * nInGM);
        AscendC::DataCopy(cGlobal, l0c, intriParams);
    }

    __aicore__ inline void operator()(AscendC::GlobalTensor<C_T> cGlobal,
                                      AscendC::GlobalTensor<A_T> aGlobal,
                                      AscendC::GlobalTensor<B_T> bGlobal,
                                      uint64_t blockNum,
                                      uint64_t curIterBatchL1,
                                      uint64_t nextIterBatchL1,
                                      uint64_t mainIterBatchL1,
                                      uint64_t mainIterBatchL0,
                                      uint64_t baseM,
                                      uint64_t baseN,
                                      uint64_t baseK,
                                      bool isPreLoadRound,
                                      bool isFinalRound)
    {
        AscendC::LocalTensor<A_T> al1Local(AscendC::TPosition::A1, 0, L1_SIZE); // start of l1
        AscendC::LocalTensor<B_T> bl1Local = al1Local[alignedM_ * alignedK_ * BUFFER_NUM * mainIterBatchL1];
        AscendC::LocalTensor<A_T> l0a(AscendC::TPosition::A2, 0, L0A_SIZE);
        AscendC::LocalTensor<B_T> l0b(AscendC::TPosition::B2, 0, L0B_SIZE);
        AscendC::LocalTensor<float> l0c(AscendC::TPosition::CO1, 0, L0C_SIZE);

        // mov align to L1 with pingpong
        if (isPreLoadRound) { // first round, copy first loop of data
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(abL1EventID_ & 0x1); // wait last loop mte1 to finish
            CopyInA1(aGlobal, al1Local[alignedM_ * alignedK_ * mainIterBatchL1 * (abL1EventID_ & 0x1)],
                     curIterBatchL1, m_, k_, alignedM_, alignedK_);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(abL1EventID_ & 0x1); // set current loop mte1 to wait

            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>((abL1EventID_ & 0x1) + L1_EVENT_ID_OFFSET);
            CopyInB1(bGlobal, bl1Local[alignedN_ * alignedK_ * mainIterBatchL1 * (abL1EventID_ & 0x1)],
                     curIterBatchL1, k_, n_, alignedK_, alignedN_);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>((abL1EventID_ & 0x1) + L1_EVENT_ID_OFFSET);
        }
        if (!isFinalRound) { // before last round need to precopy next loop of data
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>((abL1EventID_ + 1) & 0x1); // wait last loop mte1 to finish
            CopyInA1(aGlobal[m_ * k_ * mainIterBatchL1 * blockNum], al1Local[alignedM_ * alignedK_ *
                     mainIterBatchL1 * ((abL1EventID_ + 1) & 0x1)], nextIterBatchL1, m_, k_, alignedM_, alignedK_);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>((abL1EventID_ + 1) & 0x1); // set current loop mte1 to wait

            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(((abL1EventID_ + 1) & 0x1) + L1_EVENT_ID_OFFSET);
            CopyInB1(bGlobal[k_ * n_ * mainIterBatchL1 * blockNum], bl1Local[alignedN_ * alignedK_ *
                     mainIterBatchL1 * ((abL1EventID_ + 1) & 0x1)], nextIterBatchL1, k_, n_, alignedK_, alignedN_);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(((abL1EventID_ + 1) & 0x1) + L1_EVENT_ID_OFFSET);
        }

        uint64_t mL0Cnt = CeilDiv(m_, baseM);
        uint64_t nL0Cnt = CeilDiv(n_, baseN);
        uint64_t kL0Cnt = CeilDiv(k_, baseK);

        // calculate how much loop needed between l1 and l0.
        uint64_t stepIterBatchL1L0 = CeilDiv(curIterBatchL1, mainIterBatchL0);
        for (uint64_t iter1 = 0; iter1 < stepIterBatchL1L0; ++iter1) {
            uint64_t curIterBatchL0 = (iter1 + 1 == stepIterBatchL1L0) ? // if tailloop of l1 and l0, cal tail iter num.
                                      (curIterBatchL1 - mainIterBatchL0 * iter1) : mainIterBatchL0;
            for (uint64_t iterNL0 = 0; iterNL0 < nL0Cnt; ++iterNL0) {
                uint64_t curNL0 = (iterNL0 == nL0Cnt - 1) ? (n_ - (nL0Cnt - 1) * baseN) : baseN;
                for (uint64_t iterML0 = 0; iterML0 < mL0Cnt; ++iterML0) {
                    uint64_t curML0 = (iterML0 == mL0Cnt - 1) ? (m_ - (mL0Cnt - 1) * baseM) : baseM;
                    for (uint64_t iterKL0 = 0; iterKL0 < kL0Cnt; ++iterKL0) {
                        uint64_t curKL0 = (iterKL0 == kL0Cnt - 1) ? (k_ - (kL0Cnt - 1) * baseK) : baseK;
                        if (iter1 == 0 && iterNL0 == 0 && iterML0 == 0 && iterKL0 == 0) {
                            AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(abL1EventID_ & 0x1);
                        }
                        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0EventID_ & 0x1);
                        uint64_t offsetL1AOfCopyInA2 = alignedM_ * alignedK_ * mainIterBatchL1 * (abL1EventID_ & 0x1) +
                                                       iter1 * mainIterBatchL0 * alignedM_ * alignedK_ +
                                                       (AType::isTrans ? (iterML0 * alignedK_ * baseM +
                                                       iterKL0 * baseK * AscendC::AuxGetC0Size<A_T>()) :
                                                       (iterML0 * AscendC::AuxGetC0Size<A_T>() * baseM +
                                                       iterKL0 * alignedM_ * baseK));
                        CopyInA2(l0a[l0AOffset_ * (l0EventID_ & 0x1)], al1Local[offsetL1AOfCopyInA2],
                                 curKL0, curML0, alignedK_, alignedM_, curIterBatchL0);
                        if ((iter1 == stepIterBatchL1L0 - 1) && (iterNL0 == nL0Cnt - 1) && (iterML0 == mL0Cnt - 1) &&
                             (iterKL0 == kL0Cnt - 1)) {
                            // after last loop, notice Mte2 to wait Mte1
                            AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(abL1EventID_ & 0x1);
                        }

                        if (iter1 == 0 && iterNL0 == 0 && iterML0 == 0 && iterKL0 == 0) {
                            AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>((abL1EventID_ & 0x1) + L1_EVENT_ID_OFFSET);
                        }
                        uint64_t offsetL1BOfCopyInB2 = alignedN_ * alignedK_ * mainIterBatchL1 * (abL1EventID_ & 0x1) +
                                                       iter1 * mainIterBatchL0 * alignedK_ * alignedN_ +
                                                       (BType::isTrans ?
                                                       (iterNL0 * AscendC::AuxGetC0Size<B_T>() * baseN +
                                                       iterKL0 * baseK * alignedN_) :
                                                       (iterNL0 * alignedK_ * baseN +
                                                       iterKL0 * baseK * AscendC::AuxGetC0Size<B_T>()));
                        CopyInB2(l0b[l0BOffset_ * (l0EventID_ & 0x1)], bl1Local[offsetL1BOfCopyInB2],
                                 curKL0, curNL0, alignedK_, alignedN_, curIterBatchL0);
                        AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(l0EventID_ & 0x1);
                        if ((iter1 == stepIterBatchL1L0 - 1) && (iterNL0 == nL0Cnt - 1) && (iterML0 == mL0Cnt - 1) &&
                            (iterKL0 == kL0Cnt - 1)) {
                            AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>((abL1EventID_ & 0x1) + L1_EVENT_ID_OFFSET);
                        }

                        AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(l0EventID_ & 0x1);
                        if (iterKL0 == 0) {
                            AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventID_ & 0x1);
                        }
                        bool cmatrixInitVal = (iterKL0 == 0);
                        Mmad(l0a[l0AOffset_ * (l0EventID_ & 0x1)],
                             l0b[l0BOffset_ * (l0EventID_ & 0x1)],
                             l0c[l0COffset_ * (l0CEventID_ & 0x1)],
                             curML0, curNL0, curKL0, alignedM_, alignedK_, alignedK_, alignedN_, alignedM_, alignedN_,
                             curIterBatchL0, cmatrixInitVal);
                        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0EventID_ & 0x1);
                        l0EventID_++;
                    }
                    AscendC::SetFlag<AscendC::HardEvent::M_FIX>(l0CEventID_ & 0x1);

                    AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(l0CEventID_ & 0x1);
                    uint64_t offsetCGMOfCopyOut = iter1 * mainIterBatchL0 * m_ * n_ + iterML0 * baseM * n_ +
                                                  iterNL0 * baseN;
                    CopyOut(cGlobal[offsetCGMOfCopyOut], l0c[l0COffset_ * (l0CEventID_ & 0x1)], curML0, curNL0,
                            curIterBatchL0);
                    AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventID_ & 0x1);
                    l0CEventID_++;
                }
            }
        }
        abL1EventID_++;
    }

private:
    constexpr static uint16_t L1_EVENT_ID_OFFSET = 2;
    constexpr static uint16_t DIMENSION_M = 0;
    constexpr static uint16_t DIMENSION_N = 1;
    constexpr static uint16_t DIMENSION_K = 2;
    constexpr static uint16_t ZERO_FLAG = 0;
    constexpr static uint16_t FIRST_FLAG = 1;
    constexpr static uint16_t SECOND_FLAG = 2;
    constexpr static uint16_t THIRD_FLAG = 3;
};
} // namespace Block
} // namespace Gemm
} // namespace Act
#endif